import os

from data import common
from data import srdata

import numpy as np
import scipy.misc as misc

import torch
import torch.utils.data as data

class DIV2K(srdata.SRData):
    def __init__(self, args, train=True):
        super(DIV2K, self).__init__(args, train)
        self.repeat = 1#args.test_every // (args.n_train // args.batch_size)

    def _scan(self):
        # list_hr = []
        # for entry in os.scandir(self.dir_hr):
        #     filename = os.path.splitext(entry.name)[0]
        #     list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
        # list_hr.sort()
        #     list_hr = []
        if self.train:
            list_hr = [i for i in range(self.num)]
        else:
            list_hr = []
            # list_hr = []
            # # list_lr = [[] for _ in self.scale]
            for entry in os.scandir(self.dir_hr):
                filename = os.path.splitext(entry.name)[0]
                list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
            # # for si, s in enumerate(self.scale):
            #     #     list_lr[si].append(os.path.join(
            #     #         self.dir_lr,
            #     #         'X{}/{}x{}{}'.format(s, filename, s, self.ext)
            #     #     ))
            #
            list_hr.sort()
    # #     list_hr = []
    #     # list_lr = [[] for _ in self.scale]
    #     if self.train:
    #         idx_begin = 0
    #         idx_end = self.args.n_train
    #     else:
    #         idx_begin = self.args.n_train
    #         idx_end = self.args.offset_val + self.args.n_val
    #
    #     for i in range(idx_begin + 1, idx_end + 1):
    #         filename = '{:0>4}'.format(i)
    #         list_hr.append(os.path.join(self.dir_hr, filename + self.ext))
    #         # for si, s in enumerate(self.scale):
    #         #     list_lr[si].append(os.path.join(
    #         #         self.dir_lr,
    #         #         'X{}/{}x{}{}'.format(s, filename, s, self.ext)
    #         #     ))
    #
        return list_hr#[i for i in range(self.num)]#, list_lr

    # def _set_filesystem(self, dir_data):
    #     self.apath = dir_data #+ '/DIV2K'
    #     self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
    #     # self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
    #     self.ext = '.png'
    #
    # def _name_hrbin(self):
    #     return os.path.join(
    #         self.apath,
    #         'bin',
    #         '{}_bin_HR.npy'.format(self.split)
    #     )

    # def _name_lrbin(self, scale):
    #     return os.path.join(
    #         self.apath,
    #         'bin',
    #         '{}_bin_LR_X{}.npy'.format(self.split, scale)
    #     )

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

